{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a6231285",
   "metadata": {},
   "source": [
    "# Press Release Chat Bot\n",
    "\n",
    "As part of this generative AI workflow, we create an NVIDIA PR chatbot that answers questions about NVIDIA news and blog posts from 2022 and 2023. To do this, we have created a FastAPI REST server that wraps llama-index. The API server has two methods, ```upload_document``` and ```generate```. The ```upload_document``` method takes a document from the user's computer, splits it into chunks, embeds the chunks, and uploads them to a Milvus vector database. The ```generate``` method generates an answer to the provided prompt, optionally sourcing information from the vector database."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c74eaf2",
   "metadata": {},
   "source": [
    "#### Step-1: Load the PDF files from the dataset folder\n",
    "\n",
    "You can upload the PDF files containing the NVIDIA blogs to the ```chain-server:8081/documents``` API endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "263a7a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!unzip dataset.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2244b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import mimetypes\n",
    "\n",
    "def upload_document(file_path, url):\n",
    "    headers = {\n",
    "        'accept': 'application/json'\n",
    "    }\n",
    "    mime_type, _ = mimetypes.guess_type(file_path)\n",
    "    # Open the file in a context manager so the handle is closed after the upload\n",
    "    with open(file_path, 'rb') as f:\n",
    "        files = {\n",
    "            'file': (file_path, f, mime_type)\n",
    "        }\n",
    "        response = requests.post(url, headers=headers, files=files)\n",
    "\n",
    "    return response.text\n",
    "\n",
    "def upload_pdf_files(folder_path, upload_url, num_files):\n",
    "    uploaded = 0\n",
    "    for file_name in os.listdir(folder_path):\n",
    "        # Stop once num_files PDFs have been uploaded\n",
    "        if uploaded >= num_files:\n",
    "            break\n",
    "        _, ext = os.path.splitext(file_name)\n",
    "        # Ingest only pdf files\n",
    "        if ext.lower() == \".pdf\":\n",
    "            file_path = os.path.join(folder_path, file_name)\n",
    "            print(upload_document(file_path, upload_url))\n",
    "            uploaded += 1"
   ]
  },
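  {
   "cell_type": "markdown",
   "id": "9d1f3a70",
   "metadata": {},
   "source": [
    "Before uploading, it can help to preview which files would be sent. The cell below is a minimal sketch and not part of the original workflow: `list_pdf_files` is a hypothetical helper that applies the same `.pdf` filter as `upload_pdf_files` and returns at most `limit` paths. It sorts the names so the selection is reproducible, and it runs against a throwaway temporary folder with made-up file names so that it works without the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7c2e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "def list_pdf_files(folder_path, limit):\n",
    "    # Hypothetical helper: return up to `limit` PDF paths from folder_path, sorted by name\n",
    "    names = sorted(\n",
    "        name for name in os.listdir(folder_path)\n",
    "        if os.path.splitext(name)[1].lower() == \".pdf\"\n",
    "    )\n",
    "    return [os.path.join(folder_path, name) for name in names[:limit]]\n",
    "\n",
    "# Demonstrate on a throwaway folder with made-up file names, so no dataset is needed\n",
    "with tempfile.TemporaryDirectory() as demo_dir:\n",
    "    for name in [\"b.pdf\", \"a.PDF\", \"notes.txt\", \"c.pdf\"]:\n",
    "        open(os.path.join(demo_dir, name), \"w\").close()\n",
    "    preview = [os.path.basename(path) for path in list_pdf_files(demo_dir, 2)]\n",
    "\n",
    "print(preview)"
   ]
  },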
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5c99ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "start_time = time.time()\n",
    "NUM_DOCS_TO_UPLOAD = 100\n",
    "upload_pdf_files(\"dataset\", \"http://chain-server:8081/documents\", NUM_DOCS_TO_UPLOAD)\n",
    "print(f\"--- {time.time() - start_time} seconds ---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "830882ef",
   "metadata": {},
   "source": [
    "#### Step-2: Ask a question without referring to the knowledge base\n",
    "Ask the TensorRT-LLM Llama 2 13B model a question about the NVIDIA Grace superchip without consulting the vector database (knowledge base) by setting ```use_knowledge_base``` to ```false```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eb862fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import json\n",
    "\n",
    "data = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"how many cores are on the nvidia grace superchip?\"\n",
    "        }\n",
    "    ],\n",
    "    # Send a real boolean; requests serializes False as JSON false\n",
    "    \"use_knowledge_base\": False,\n",
    "    \"max_tokens\": 256\n",
    "}\n",
    "\n",
    "url = \"http://chain-server:8081/generate\"\n",
    "\n",
    "start_time = time.time()\n",
    "with requests.post(url, stream=True, json=data) as req:\n",
    "    for chunk in req.iter_lines():\n",
    "        raw_resp = chunk.decode(\"UTF-8\")\n",
    "        if not raw_resp:\n",
    "            continue\n",
    "        # Skip the first 6 characters (assumed to be a \"data: \" prefix) before parsing the JSON\n",
    "        resp_dict = json.loads(raw_resp[6:])\n",
    "        resp_choices = resp_dict.get(\"choices\", [])\n",
    "        if len(resp_choices):\n",
    "            resp_str = resp_choices[0].get(\"message\", {}).get(\"content\", \"\")\n",
    "            print(resp_str, end=\"\")\n",
    "\n",
    "print(f\"\\n--- {time.time() - start_time} seconds ---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcf37ee9",
   "metadata": {},
   "source": [
    "Now ask the same question with ```use_knowledge_base``` set to ```true```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e904a658",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"how many cores are on the nvidia grace superchip?\"\n",
    "        }\n",
    "    ],\n",
    "    # Send a real boolean; requests serializes True as JSON true\n",
    "    \"use_knowledge_base\": True,\n",
    "    \"max_tokens\": 50\n",
    "}\n",
    "\n",
    "url = \"http://chain-server:8081/generate\"\n",
    "\n",
    "start_time = time.time()\n",
    "tokens_generated = 0\n",
    "with requests.post(url, stream=True, json=data) as req:\n",
    "    for chunk in req.iter_lines():\n",
    "        raw_resp = chunk.decode(\"UTF-8\")\n",
    "        if not raw_resp:\n",
    "            continue\n",
    "        # Skip the first 6 characters (assumed to be a \"data: \" prefix) before parsing the JSON\n",
    "        resp_dict = json.loads(raw_resp[6:])\n",
    "        resp_choices = resp_dict.get(\"choices\", [])\n",
    "        if len(resp_choices):\n",
    "            resp_str = resp_choices[0].get(\"message\", {}).get(\"content\", \"\")\n",
    "            # Approximation: count each non-empty streamed chunk as one token\n",
    "            if resp_str:\n",
    "                tokens_generated += 1\n",
    "            print(resp_str, end=\"\")\n",
    "\n",
    "total_time = time.time() - start_time\n",
    "print(f\"\\n--- Generated {tokens_generated} tokens in {total_time} seconds ---\")\n",
    "print(f\"--- {tokens_generated/total_time} tokens/sec\")"
   ]
  },
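  {
   "cell_type": "markdown",
   "id": "e3a86f91",
   "metadata": {},
   "source": [
    "The two ```generate``` cells above parse each streamed line by slicing off its first six characters with `raw_resp[6:]`. The cell below is a minimal sketch that makes this step explicit, assuming each line carries a `data: ` prefix followed by a JSON payload with a `choices` list. `extract_content` and the sample lines are hypothetical: they are shaped after what the cells above read, not captured from the server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c40d7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Assumption: every streamed line starts with this 6-character prefix\n",
    "SSE_PREFIX = \"data: \"\n",
    "\n",
    "def extract_content(line):\n",
    "    # Hypothetical helper: return the message text in one streamed line, or \"\" if there is none\n",
    "    if not line.startswith(SSE_PREFIX):\n",
    "        return \"\"\n",
    "    payload = json.loads(line[len(SSE_PREFIX):])\n",
    "    choices = payload.get(\"choices\", [])\n",
    "    if not choices:\n",
    "        return \"\"\n",
    "    return choices[0].get(\"message\", {}).get(\"content\", \"\")\n",
    "\n",
    "# Made-up sample lines, including an empty line and a payload without choices\n",
    "sample_lines = [\n",
    "    'data: {\"choices\": [{\"message\": {\"content\": \"Hello\"}}]}',\n",
    "    '',\n",
    "    'data: {\"choices\": [{\"message\": {\"content\": \" world\"}}]}',\n",
    "    'data: {\"choices\": []}',\n",
    "]\n",
    "answer = \"\".join(extract_content(line) for line in sample_lines)\n",
    "print(answer)"
   ]
  },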
  {
   "cell_type": "markdown",
   "id": "58954d15",
   "metadata": {},
   "source": [
    "#### Next steps\n",
    "\n",
    "We have set up a playground UI where you can upload files and ask questions. The UI is available on the same IP address as the notebooks: `host_ip:8090/converse`"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
